Sparse neural networks have garnered attention due to their theoretical promise of lower computational demands and memory savings. However, to date, these theoretical gains have largely failed to materialize due to the lack of hardware support for this kind of model. In this work, we explore the idea of neuron recycling, which is inspired by pruning, a method often employed to induce sparsity in neural networks. We also present lessons we have learned along the way.
Introduction
Pruning is a well-established technique for sparsifying neural networks. It relies on the fact that a large part of a trained neural network can typically be masked without impacting the accuracy of the network, although additional fine-tuning is often required to regain some lost performance. Despite multiple works proposing various neuron-selection criteria for pruning, magnitude-based pruning remains a viable option.
The Lottery Ticket Hypothesis (LTH) is a major finding on the way to explaining how initialization impacts neural networks. The main point of the LTH is that, through iterative pruning, one can find performant subnetworks whose quality depends on their initialization. Those well-initialized network fragments are the namesake of the LTH (the “lottery tickets”). Although some notions of the original LTH paper have been challenged, it has remained the subject of active research and a motivation for our work.
By combining the two ideas (pruning and the LTH), we arrive at a new potential technique for raising neural network performance. If we are able to remove parts of the network without hurting performance (pruning), and the fitness of a part of a network is determined at initialization, perhaps we could re-initialize the unnecessary network parts (i.e. draw more “lottery tickets”), leading to a better-performing network.
Preliminaries
Before we move to the presentation of our experiments and findings, let’s first discuss the training setup, define key terminology, and go over the basics.
Model and training setup
In our project, we focus on the Transformer, since it is a major architecture across different domains. Specifically, we work with the encoder-only BERT. Taking into consideration the available computational resources and the expected iteration time (we wanted to try as many options as possible), we opted for the BERT Medium configuration (with \(d_\text{model}=512\) and \(8\) attention heads). We focus on the feed-forward layer because it is the most computationally demanding part of commonly used transformer models and, in large models, it contains the majority of the parameters. At the same time, most existing research focuses on the attention mechanism, which leaves the feed-forward layer relatively unexplored.
We trained the model for \(80{,}000\) steps (roughly the compute-optimal number of training samples for a model of this size), with Adam, using a batch size of \(256\) and a learning rate of \(0.001\).
In the following part of this post, we will often use the terms neuron and magnitude. Below are the definitions we employ.
Neuron
In the Transformer, the feed-forward layer consists of two linear layers with a nonlinearity in between. The first layer maps the input vector from dimension \(d_\text{model}\) to \(d_\text{ff}\), and the second one maps it from \(d_\text{ff}\) back to \(d_\text{model}\). Typically, \(d_\text{ff}\) is four times greater than \(d_\text{model}\). By a neuron, we mean all weights interacting with a particular coordinate in the \(\mathbb{R}^{d_\text{ff}}\) activation vector. In the torch implementation, a neuron’s weights are the parameters in a given row of the first feed-forward matrix and in the corresponding column of the second one.
Magnitude
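We define the magnitude of a weight as its absolute value. We define the magnitude of a neuron as \(M=\sqrt{\sum{w_{in}^2} \sum{w_{out}^2}}\), where \(w_{in}\) ranges over the neuron’s weights in the first feed-forward matrix and \(w_{out}\) over its weights in the second one.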
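To make the definition concrete, here is a minimal NumPy sketch of the neuron magnitude. It assumes the torch layout described above (neurons are rows of the first matrix and columns of the second) and leaves out biases, which is our simplification. The function name `neuron_magnitudes` and the toy matrices are hypothetical.

```python
import numpy as np

def neuron_magnitudes(w_in: np.ndarray, w_out: np.ndarray) -> np.ndarray:
    """Return M = sqrt(sum(w_in^2) * sum(w_out^2)) for every neuron.

    w_in:  (d_ff, d_model), one row per neuron (first linear layer).
    w_out: (d_model, d_ff), one column per neuron (second linear layer).
    Biases are ignored in this sketch.
    """
    in_sq = np.sum(w_in ** 2, axis=1)    # shape (d_ff,)
    out_sq = np.sum(w_out ** 2, axis=0)  # shape (d_ff,)
    return np.sqrt(in_sq * out_sq)

# Toy example with d_model = 2 and d_ff = 3.
w_in = np.array([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])
w_out = np.array([[1.0, 5.0, 0.0], [0.0, 5.0, 2.0]])
magnitudes = neuron_magnitudes(w_in, w_out)  # array([5., 0., 2.])
```

Note that the second neuron has magnitude 0 even though its outgoing weights are large: under this definition a neuron is small whenever either side of it is small.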
Pruning
Pruning is a technique used to induce sparsity and decrease the parameter count in a neural network. In simple terms, it means deleting the least important neurons (structured pruning) or weights (unstructured pruning). A typical implementation realizes this by either multiplying the output of the deleted neurons by 0 or setting the weights of the neuron to 0. A widely-used proxy for the importance of a neuron or weight is its magnitude.
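As an illustration, structured magnitude pruning by output masking could look like the NumPy sketch below. The helper names are hypothetical, biases are omitted, and ReLU stands in for whatever nonlinearity the model uses. It shows the general technique, not our exact implementation.

```python
import numpy as np

def magnitude_prune_mask(magnitudes: np.ndarray, fraction: float) -> np.ndarray:
    """Return a 0/1 mask that zeroes the `fraction` lowest-magnitude neurons."""
    n_prune = int(round(fraction * magnitudes.size))
    mask = np.ones(magnitudes.size)
    if n_prune > 0:
        pruned = np.argsort(magnitudes, kind="stable")[:n_prune]
        mask[pruned] = 0.0
    return mask

def feed_forward(x, w_in, w_out, mask):
    """Feed-forward layer whose pruned neurons have their output multiplied by 0."""
    hidden = np.maximum(x @ w_in.T, 0.0)  # (batch, d_ff), ReLU is an assumption
    return (hidden * mask) @ w_out.T       # (batch, d_model)

magnitudes = np.array([0.5, 3.0, 0.1, 2.0])
mask = magnitude_prune_mask(magnitudes, fraction=0.5)  # array([0., 1., 0., 1.])
```

Masking the activation and zeroing the neuron's weights give the same layer output here; masking has the advantage that the original weights are kept.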
Below we present a plot with the loss curves of a model whose FF layer is gradually pruned, starting at step \(10{,}000\), such that the layer is completely masked by the end of training. For comparison, we also include the regular model and a model without the feed-forward layer.
Interestingly, the effect of pruning is not visible for a significant fraction of the training time. It is also worth noting that, in the end, the model without the FF layer performs slightly better than the pruned one. This is because, in the former case, attention was trained to adjust from the very beginning of training.
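A gradual pruning schedule of this kind can be sketched as a function from the training step to the fraction of masked neurons. The linear ramp is an assumption made for illustration (the text only fixes the start step and the fully masked end state), and `pruned_fraction` is a hypothetical name.

```python
def pruned_fraction(step: int, start_step: int = 10_000, end_step: int = 80_000) -> float:
    """Fraction of feed-forward neurons masked at `step` (assumed linear ramp)."""
    if step <= start_step:
        return 0.0   # no pruning before the start step
    if step >= end_step:
        return 1.0   # the layer is completely masked at the end
    return (step - start_step) / (end_step - start_step)

halfway = pruned_fraction(45_000)  # 0.5
```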
The goal
The end-goal of the project was to create a method that would allow us to make better use of the parameters in the feed-forward layer. In this context, a natural question arises - against what baseline should our results be compared? To answer this question, we trained the model with differing dimensionalities of the feed-forward layer. The results are presented below. The true BERT Medium configuration has \(d_\text{ff}=2048\), and, as expected, the model’s performance drops when the \(d_\text{ff}\) is decreased and increases when the \(d_\text{ff}\) is increased.
Understanding neuron magnitudes
One of the key inspirations for our work was structured pruning, where neuron magnitude is often chosen as the measure of significance. We were interested in how this metric evolves during the training process. At first, we expected a histogram of neuron magnitudes to resemble a normal distribution. However, our experiments showed something different. The following graph shows the evolution of neuron magnitudes throughout the training process.
In the early stages of training, the neurons split into two groups, one featuring much lower magnitudes than the other. This finding opens up a multitude of discussion topics. It can be speculated that the neurons belonging to the group with smaller magnitudes potentially don’t hold much importance and can be pruned freely. However, it’s also possible that these neurons, though small, play a critical role in specific tasks.
This phenomenon is not limited to the first layer of the network. We have observed it in all layers, apart from the last one, as shown in the following plot.
After examining these experiments, we tried to understand why, in the early layers, we observed two distinct groups of neurons, categorized by their magnitudes. One possible explanation is that certain parts of the network receive a smaller signal and are slower to improve in training. We designed an experiment to check that. We periodically froze all parts of the network except for the feed-forward component and continued to train it for several batches of data. We hypothesized that in this scenario, weaker neurons might catch up, resulting in a more even distribution. We called this procedure overtraining the feed-forward layer. It’s important to note that this approach is impractical and computationally heavy, but we wanted to use it for the purpose of illustration. The results are depicted in the following plot.
We can see that the group of “weaker” neurons has moved to the right after additional training of the FF part. However, the neurons still form two distinct groups: overtraining the whole layer is not enough for the weaker ones to catch up. In the next experiment, we examined overtraining only the small-magnitude neurons, only the large-magnitude neurons, and random subsets. How does it affect the performance? The results are depicted in the following plot.
Overtraining only the smallest neurons yields better results than reinforcing the high-magnitude ones. Notably, overtraining the small ones gives similar gains in performance to working on the entire layer! In contrast, amplifying the highest ones gives gains similar to no overtraining at all. This provides a compelling argument in favor of our technique.
Magnitudes in openly available pretrained models
So far, we have performed a series of experiments on relatively small architectures. We were curious to see how our observations would translate to well-established, large-scale foundation models like BERT Large or T5.
There is a clear difference between the plots above. Magnitudes in T5 seem similar to those in our smaller models, while BERT presents a more balanced distribution. What could account for these variations? We discovered that the use of weight decay in BERT Large plays a significant role. This simple but widely used technique has an important impact on the distribution.
These findings support the idea of exploring neuron recycling and offer a good foundation for further experiments. In the next sections, we will present results on this topic and share our insights.
Recycling
The central part of our work was a method we called neuron recycling. It consists of three phases, repeated periodically: training, selection and reinitialization.

- In the training phase, the model is trained to predict masked tokens (masked language modelling).
- In the selection phase, the least important neurons are determined, where the baseline criterion is neuron magnitude.
- In the reinitialization phase, new weights are assigned to neurons.
Although this procedure is conceptually simple, it allows for many degrees of freedom. Here are some choices that can be made:
- The number of training steps before consecutive selection / reinitialization phases
- The percentage of recycled neurons
- Selection / reinitialization strategies
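The three phases and the choices listed above can be put together in a short NumPy sketch. Everything named here is hypothetical: `train_step` is a caller-supplied callback standing in for a real masked-language-modelling update, and `recycle_every`, `fraction` and `init_std` correspond to the degrees of freedom above. Selection uses magnitude, and reinitialization samples from a zero-mean normal, which is an assumption about the initial distribution.

```python
import numpy as np

def recycle_step(w_in, w_out, fraction, init_std, rng):
    """Select the lowest-magnitude neurons and resample their weights in place.

    Returns the indices of the recycled neurons.
    """
    magnitudes = np.sqrt(np.sum(w_in ** 2, axis=1) * np.sum(w_out ** 2, axis=0))
    n_recycle = int(fraction * magnitudes.size)
    chosen = np.argsort(magnitudes)[:n_recycle]              # selection phase
    # Reinitialization phase: rows of w_in and columns of w_out.
    w_in[chosen, :] = rng.normal(0.0, init_std, size=(n_recycle, w_in.shape[1]))
    w_out[:, chosen] = rng.normal(0.0, init_std, size=(w_out.shape[0], n_recycle))
    return chosen

def train_with_recycling(w_in, w_out, train_step, total_steps, recycle_every,
                         fraction, init_std, seed=0):
    """Alternate training with periodic selection and reinitialization."""
    rng = np.random.default_rng(seed)
    history = []
    for step in range(1, total_steps + 1):
        train_step(w_in, w_out)                               # training phase
        if step % recycle_every == 0:
            history.append(recycle_step(w_in, w_out, fraction, init_std, rng))
    return history

# Toy run with a no-op training step, just to exercise the control flow.
rng = np.random.default_rng(0)
w_in = rng.normal(0.0, 1.0, size=(8, 4))
w_out = rng.normal(0.0, 1.0, size=(4, 8))
history = train_with_recycling(w_in, w_out, train_step=lambda a, b: None,
                               total_steps=10, recycle_every=5,
                               fraction=0.25, init_std=0.02)
```

Keeping `history` around makes it easy to count how often each neuron was recycled.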
After examining the pruning literature, we have found that the simple magnitude-based approach works best in most cases.
As you can see, removing low-magnitude neurons hurts the model the least, and removing high-magnitude ones causes the largest loss. This is a good argument that this criterion correlates well with neuron significance.
Baseline recycling
The most straightforward reinitialization scheme is to sample the weights of the reinitialized neurons from the initial distribution. After examining the performance of this solution, we could not see any difference between recycling and vanilla training.
As a sanity check, we have examined the histogram presenting the number of times each neuron was recycled, discovering that the same small subset of neurons was being reinitialized over and over during training.
As we have seen, the magnitude of neurons grows on average throughout training. Therefore, sampling from the initial distribution causes the recycled neurons to have even lower magnitudes. As a result, they are unable to catch up before the next selection phase. Thus, the recycled neurons are caught in a vicious cycle in which they are always recycled before reaching a high magnitude.
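The argument can be reproduced in a deliberately simplified toy model, sketched below. It assumes that every magnitude grows by the same factor between selection phases and that a recycled neuron restarts exactly at the initial magnitude. Neither assumption comes from our measurements; they only isolate the mechanism.

```python
import numpy as np

def count_distinct_recycled(n_neurons=100, n_recycle=10, rounds=20,
                            growth=1.1, init_magnitude=1.0):
    """Toy model: all magnitudes grow by `growth` between selection phases,
    and recycled neurons restart from `init_magnitude`."""
    # Distinct starting magnitudes, all at or above the initial scale.
    magnitudes = init_magnitude * (1.0 + np.arange(n_neurons) / n_neurons)
    recycled = set()
    for _ in range(rounds):
        magnitudes *= growth                          # training phase
        chosen = np.argsort(magnitudes)[:n_recycle]   # selection phase
        recycled.update(chosen.tolist())
        magnitudes[chosen] = init_magnitude           # reinitialization phase
    return len(recycled)

distinct = count_distinct_recycled()  # 10
```

In this toy setting, 20 rounds of recycling 10 neurons each only ever touch the same 10 neurons, because a freshly reset neuron is always smaller than every neuron that has kept growing.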
Immunity
To address the problem we observed in the previous approach, we tried another strategy: recycling with immunity. The idea here is to encourage diverse recycling by making each recycled neuron immune to further recycling for some predefined number of steps. We hypothesized that a reinitialized neuron needs some time to grow, which was not possible in the initial setting. The following plot illustrates that immunity prevents the recycled neurons from being caught in a vicious cycle.
A higher number of immunity rounds (i.e. the number of selection phases during which a newly recycled neuron can’t be chosen) causes more neurons to be reinitialized at least once. Unfortunately, this eventually causes well-behaving parts of the network to be chosen for recycling. As a result, the performance drops.
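One way to implement immunity is a per-neuron counter of selection phases still to be skipped, as in the NumPy sketch below. The function name and the counter layout are hypothetical, and the toy example keeps the magnitudes fixed to show only the selection logic.

```python
import numpy as np

def select_with_immunity(magnitudes, immunity, n_recycle, immunity_rounds):
    """Pick the lowest-magnitude neurons that are not currently immune.

    `immunity` holds, per neuron, how many upcoming selection phases it must
    still skip; it is updated in place.
    """
    eligible = np.flatnonzero(immunity == 0)
    order = eligible[np.argsort(magnitudes[eligible], kind="stable")]
    chosen = order[:n_recycle]
    immunity[immunity > 0] -= 1          # one selection phase has passed
    immunity[chosen] = immunity_rounds   # protect the freshly recycled neurons
    return chosen

magnitudes = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
immunity = np.zeros(6, dtype=int)
first = select_with_immunity(magnitudes, immunity, n_recycle=2, immunity_rounds=1)
second = select_with_immunity(magnitudes, immunity, n_recycle=2, immunity_rounds=1)
third = select_with_immunity(magnitudes, immunity, n_recycle=2, immunity_rounds=1)
# first -> neurons 0 and 1, second -> neurons 2 and 3, third -> neurons 0 and 1
```

The second call shows the side effect described above: with the two smallest neurons protected, selection moves on to larger, possibly well-behaving neurons.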
Modifying reinitialization distribution
As we have pointed out before, the magnitude and weight distributions drift away from the initial distribution as training progresses. However, in our initial attempts, we sampled the new weights from the initial distribution. To fix this issue, we decided to try out another weight sampling technique. In this approach, we used the normal distribution with mean and standard deviation equal to the mean and standard deviation of all the weights in the respective matrix. This approach, like immunity, eliminated the vicious cycle problem.
However, this process introduced a lot of noise with adverse effect on the model’s loss.
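A minimal NumPy sketch of this sampling scheme is given below. The function name and the `axis` convention are hypothetical; the statistics are taken over the whole matrix, as described above.

```python
import numpy as np

def reinit_from_current_stats(weights, neuron_idx, axis, rng):
    """Resample the chosen neurons from N(mean, std) of the whole matrix.

    axis=0 treats rows as neurons (first layer), axis=1 treats columns as
    neurons (second layer). Modifies `weights` in place.
    """
    mean, std = weights.mean(), weights.std()
    if axis == 0:
        size = (len(neuron_idx), weights.shape[1])
        weights[neuron_idx, :] = rng.normal(mean, std, size=size)
    else:
        size = (weights.shape[0], len(neuron_idx))
        weights[:, neuron_idx] = rng.normal(mean, std, size=size)
    return mean, std

rng = np.random.default_rng(0)
w_in = rng.normal(0.0, 0.5, size=(6, 4))
before = w_in.copy()
mean, std = reinit_from_current_stats(w_in, [1, 3], axis=0, rng=rng)
```

Because the new weights match the scale of the trained matrix, the recycled neurons start with magnitudes comparable to the rest of the layer rather than far below it.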
Copying existing neurons
Inspired by the Gopher paper, we have also tried to copy existing neurons while adding some noise to their weights. This approach was also unsuccessful.
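A sketch of copying with noise is shown below, under stated assumptions: the caller chooses which neurons to overwrite (`targets`) and which to copy (`sources`), and the noise is zero-mean Gaussian with a hypothetical `noise_std`. How the sources are picked and how the noise is scaled are open choices not fixed by the text.

```python
import numpy as np

def copy_neurons_with_noise(w_in, w_out, targets, sources, noise_std, rng):
    """Overwrite `targets` with noisy copies of the `sources` neurons (in place)."""
    targets, sources = np.asarray(targets), np.asarray(sources)
    noise_in = rng.normal(0.0, noise_std, size=(targets.size, w_in.shape[1]))
    noise_out = rng.normal(0.0, noise_std, size=(w_out.shape[0], targets.size))
    w_in[targets, :] = w_in[sources, :] + noise_in      # rows of the first matrix
    w_out[:, targets] = w_out[:, sources] + noise_out   # columns of the second

rng = np.random.default_rng(0)
w_in = rng.normal(0.0, 1.0, size=(6, 4))
w_out = rng.normal(0.0, 1.0, size=(4, 6))
# Overwrite neurons 0 and 1 with noisy copies of neurons 4 and 5.
copy_neurons_with_noise(w_in, w_out, targets=[0, 1], sources=[4, 5],
                        noise_std=0.01, rng=rng)
```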
Smooth recycling
In an effort to explain our past problems, we pointed to the discrete nature of our technique, as opposed to the continuous landscape of training/optimizers. To change this, we modified our strategy to change the value gradually by linear interpolation. More precisely, the new value assigned to a recycled weight was \[ x = \alpha \ x_{new} + (1-\alpha) \ stop\_grad(x_{old}),\] where:
- \(x_{old}\) - old value of the weight before recycling
- \(x_{new}\) - new value chosen for the weight
- \(\alpha\) - non-trainable parameter, changed linearly from 0 to 1 in the steps following reinitialization
- \(stop\_grad\) - stops the gradient from flowing through.
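The interpolation can be sketched in NumPy as follows. The names `steps_since_recycle` and `ramp_steps` are hypothetical parameters for the linear schedule of \(\alpha\). NumPy has no gradients, so \(stop\_grad\) is a no-op here; in an autodiff framework \(x_{old}\) would be detached from the graph instead (for example with `.detach()` in torch).

```python
import numpy as np

def smooth_recycled_value(x_old, x_new, steps_since_recycle, ramp_steps):
    """Blend old and new weights: alpha * x_new + (1 - alpha) * x_old.

    alpha rises linearly from 0 to 1 over `ramp_steps` steps and is clipped
    to [0, 1] outside that range.
    """
    alpha = min(max(steps_since_recycle / ramp_steps, 0.0), 1.0)
    return alpha * x_new + (1.0 - alpha) * x_old

x_old = np.array([1.0, -2.0])
x_new = np.array([0.0, 2.0])
halfway = smooth_recycled_value(x_old, x_new, steps_since_recycle=50, ramp_steps=100)
# halfway is array([0.5, 0. ])
```

At `steps_since_recycle=0` the weight still equals its old value, so the recycling event itself no longer causes a jump in the layer output.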
We have noticed that this way we were able to make training loss dynamics smoother, but without beating the baseline.
Tangent - Midpoint Loss
While working towards the main goal of the project, we started investigating the distribution of neuron magnitudes during the training. We noticed that the magnitude distribution is quite uneven - a large percentage of neurons remains small, and the distribution is right-skewed. Since the goal of our project was to reduce the number of low-quality, i.e., small neurons, we came up with a pretty risky solution: Midpoint Loss. The idea was to introduce an additional loss term that would encourage neuron growth and “even-ness” of the magnitude distribution. The general equation for the midpoint loss is:
\[ Loss = \sum_{l = 1}^{L} \sum_{n = 1}^{d_\text{ff}} dist(M_{l,n}, stop\_grad(avg_n(M_{l,n})))\] where:
- \(L\) - number of layers
- \(d_\text{ff}\) - number of neurons in a layer. In some experiments, we only summed over neurons with magnitude below the average magnitude of the layer, to encourage growth of small neurons, without thwarting the growth of the large ones
- \(dist\) - distance function, typically \(l_1\) or \(l_2\)
- \(M_{l,n}\) - magnitude of the \(n^{\text{th}}\) neuron in the \(l^{\text{th}}\) layer. In some experiments we used the \(log\) of the magnitude
- \(stop\_grad\) - stops the gradient from flowing through
- \(avg\) - arithmetic mean. In some experiments, median was used instead due to its robustness to outliers.
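For reference, the value of the midpoint loss and the variants listed above can be computed with the NumPy sketch below. The function and flag names are hypothetical. Only the loss value is computed, so \(stop\_grad\) shows up merely as treating the per-layer target as a constant. When the log variant is combined with the below-average variant, the comparison is done in log space, which is our assumption.

```python
import numpy as np

def midpoint_loss(magnitudes, dist="l1", only_below_target=False,
                  use_median=False, use_log=False):
    """Midpoint loss value for neuron magnitudes of shape (L, d_ff)."""
    m = np.log(magnitudes) if use_log else magnitudes
    reduce = np.median if use_median else np.mean
    target = reduce(m, axis=1, keepdims=True)   # per-layer midpoint, a constant
    diff = m - target
    if only_below_target:
        diff = np.where(m < target, diff, 0.0)  # only push small neurons up
    per_neuron = np.abs(diff) if dist == "l1" else diff ** 2
    return float(per_neuron.sum())              # sum over layers and neurons

# One layer with three neurons; the mean magnitude is 3.
magnitudes = np.array([[1.0, 2.0, 6.0]])
loss_l1 = midpoint_loss(magnitudes)                             # 2 + 1 + 3 = 6
loss_l2 = midpoint_loss(magnitudes, dist="l2")                  # 4 + 1 + 9 = 14
loss_small = midpoint_loss(magnitudes, only_below_target=True)  # 2 + 1 = 3
```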
Since this idea is quite similar to weight decay, we decided not to optimize this term with Adam, but to split it from the task loss and optimize it using simple gradient descent - a similar technique is used in AdamW to incorporate the weight decay term.
Midpoint loss achieved the goal of boosting the small neurons; however, it failed to make a positive impact on the model’s performance.
Conclusion
In this work, we described our attempts to integrate pruning and the Lottery Ticket Hypothesis via neuron recycling. Although we were not able to beat the baseline using our technique, we explored the topic thoroughly and conducted a series of experiments, providing valuable insights into the inner workings of a transformer. We hope that our findings may be a helpful resource for future studies and investigations in this area.